# Description:
#   Memory Space Assignment service implementation.

load(
    "//xla:xla.default.bzl",
    "xla_cc_test",
    "xla_py_proto_library",
)
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.bzl", "internal_visibility")
load(
    "//xla/tsl/platform:build_config.bzl",
    "tf_proto_library",
)
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

# Package-wide defaults: every target is visible to the ":friends" package
# group (passed through internal_visibility) unless it sets its own
# visibility, and the package is declared under the "notice" license.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = internal_visibility([":friends"]),
    licenses = ["notice"],
)

# Package group used as this package's default visibility. It lists no
# packages of its own and only pulls in the members of //xla:friends.
package_group(
    name = "friends",
    includes = [
        "//xla:friends",
    ],
)

# Proto library for memory_space_assignment.proto.
# C++ targets in this file depend on ":memory_space_assignment_proto_cc",
# which is presumably generated by the tf_proto_library macro from this
# target's name -- confirm against //xla/tsl/platform:build_config.bzl.
tf_proto_library(
    name = "memory_space_assignment_proto",
    srcs = ["memory_space_assignment.proto"],
    make_default_target_header_only = True,
)

# Python proto library wrapping :memory_space_assignment_proto, declared via
# the xla_py_proto_library macro.
xla_py_proto_library(
    name = "memory_space_assignment_proto_py_pb2",
    deps = [":memory_space_assignment_proto"],
)

# Library built from memory_space_assignment.{h,cc}.
# Within this package it depends on :algorithm, :allocation, :options,
# :simulator and :slice, plus the generated C++ proto target.
cc_library(
    name = "memory_space_assignment",
    srcs = ["memory_space_assignment.cc"],
    hdrs = ["memory_space_assignment.h"],
    deps = [
        ":algorithm",
        ":allocation",
        ":memory_space_assignment_proto_cc",
        ":options",
        ":simulator",
        ":slice",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/analysis:hlo_dataflow_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_proto_cc",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:casts",
    ],
)

# Test built from memory_space_assignment_test.cc. Unlike the xla_cc_test
# targets in this file it is declared with xla_test and lists only the "cpu"
# backend. It links //xla/tests:xla_internal_test_main together with plain
# ":gtest" instead of ":gtest_main", and uses the shared fixture header from
# :memory_space_assignment_test_base.
# NOTE(review): the "test_migrated_to_hlo_runner_pjrt" tag presumably marks
# the test as already running on the PjRt-based HLO runner -- confirm how
# xla_test interprets it.
xla_test(
    name = "memory_space_assignment_test",
    srcs = ["memory_space_assignment_test.cc"],
    backends = [
        "cpu",
    ],
    tags = [
        "test_migrated_to_hlo_runner_pjrt",
    ],
    deps = [
        ":algorithm",
        ":allocation",
        ":allocation_value",
        ":best_fit_repacker",
        ":buffer_interval_comparator",
        ":cost_analysis",
        ":memory_space_assignment",
        ":memory_space_assignment_proto_cc",
        ":memory_space_assignment_test_base",
        ":options",
        ":prefetch_interval_picker",
        ":repacking",
        ":slice",
        ":testing_utils",
        ":utils",
        "//xla:comparison_util",
        "//xla:literal_util",
        "//xla:shape_tree",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/analysis:hlo_dataflow_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/hlo/utils:hlo_matchers",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
        "@tsl//tsl/platform:protobuf",
    ],
)

# Header-only target (no srcs) exposing repacking.h.
# :best_fit_repacker, :options and :algorithm in this file depend on it.
cc_library(
    name = "repacking",
    hdrs = ["repacking.h"],
    deps = [
        "//xla/service/heap_simulator:allocation_block",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
    ],
)

# Library built from best_fit_repacker.{h,cc}. Its only in-package dependency
# is the :repacking header target.
cc_library(
    name = "best_fit_repacker",
    srcs = ["best_fit_repacker.cc"],
    hdrs = ["best_fit_repacker.h"],
    deps = [
        ":repacking",
        "//xla:comparison_util",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

# Test-only, header-only target (no srcs) exposing
# memory_space_assignment_test_base.h. :memory_space_assignment_test depends
# on it. It uses ":gtest_for_library" rather than ":gtest_main" because it is
# a library, not a test binary.
cc_library(
    name = "memory_space_assignment_test_base",
    testonly = True,
    hdrs = ["memory_space_assignment_test_base.h"],
    deps = [
        ":buffer_interval_comparator",
        ":cost_analysis",
        ":memory_space_assignment",
        ":options",
        ":prefetch_interval_picker",
        ":utils",
        "//xla:shape_util",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/transforms/simplifiers:instruction_hoister",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:buffer_value",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "//xla/tests:hlo_pjrt_test_base",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_for_library",
    ],
)

# Library built from utils.{h,cc}. Besides XLA and Abseil it depends on RE2
# and on three HighwayHash targets.
# NOTE(review): this target uses "@tsl//tsl/platform:statusor" while several
# other targets in this file use "//xla/tsl/platform:statusor" -- confirm
# which header utils.cc/utils.h actually include.
cc_library(
    name = "utils",
    srcs = ["utils.cc"],
    hdrs = ["utils.h"],
    deps = [
        ":memory_space_assignment_proto_cc",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_googlesource_code_re2//:re2",
        "@highwayhash",
        "@highwayhash//:arch_specific",
        "@highwayhash//:hh_types",
        "@tsl//tsl/platform:statusor",
    ],
)

# Library built from slice.{h,cc}. Its only in-package dependency is the
# generated C++ proto target.
cc_library(
    name = "slice",
    srcs = ["slice.cc"],
    hdrs = [
        "slice.h",
    ],
    deps = [
        ":memory_space_assignment_proto_cc",
        "//xla:shape_util",
        "//xla/service:time_utils",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

# Test for :slice, built from slice_test.cc and linked against gtest_main.
xla_cc_test(
    name = "slice_test",
    srcs = ["slice_test.cc"],
    deps = [
        ":slice",
        "//xla/service:time_utils",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log:check",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library built from allocation.{h,cc}. In-package dependencies are :slice and
# the generated C++ proto target.
# NOTE(review): this target uses "@tsl//tsl/platform:errors" and
# "@tsl//tsl/platform:statusor" while other targets in this file use the
# "//xla/tsl/platform:..." labels -- confirm which headers the sources include.
cc_library(
    name = "allocation",
    srcs = ["allocation.cc"],
    hdrs = [
        "allocation.h",
    ],
    deps = [
        ":memory_space_assignment_proto_cc",
        ":slice",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:hlo_value",
        "//xla/service:time_utils",
        "//xla/service:tuple_util",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:casts",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
    ],
)

# Test for :allocation, built from allocation_test.cc. It depends on
# //xla/hlo/testlib:hlo_hardware_independent_test_base and links gtest_main.
xla_cc_test(
    name = "allocation_test",
    srcs = ["allocation_test.cc"],
    deps = [
        ":allocation",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "//xla/tsl/lib/core:status_test_util",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@tsl//tsl/platform:statusor",
    ],
)

# Library built from tuning_utils.{h,cc}. It has no in-package dependencies;
# :algorithm in this file depends on it.
cc_library(
    name = "tuning_utils",
    srcs = ["tuning_utils.cc"],
    hdrs = ["tuning_utils.h"],
    deps = [
        "//xla/hlo/ir:hlo",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log:check",
    ],
)

# Library built from options.{h,cc}. In-package dependencies are
# :allocation_value, :buffer_interval_comparator, :cost_analysis,
# :prefetch_interval_picker, :repacking and :slice, plus the generated C++
# proto target.
cc_library(
    name = "options",
    srcs = ["options.cc"],
    hdrs = ["options.h"],
    deps = [
        ":allocation_value",
        ":buffer_interval_comparator",
        ":cost_analysis",
        ":memory_space_assignment_proto_cc",
        ":prefetch_interval_picker",
        ":repacking",
        ":slice",
        "//xla:shape_tree",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/service:buffer_value",
        "//xla/service:hlo_value",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)

# Library built from cost_analysis.{h,cc}. It has no in-package dependencies;
# many other targets in this file depend on it.
# NOTE(review): uses "@tsl//tsl/platform:statusor" while other targets in this
# file use "//xla/tsl/platform:statusor" -- confirm which header is included.
cc_library(
    name = "cost_analysis",
    srcs = ["cost_analysis.cc"],
    hdrs = ["cost_analysis.h"],
    deps = [
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/analysis:while_loop_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:call_graph",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:statusor",
    ],
)

# Test for :cost_analysis, built from cost_analysis_test.cc. It depends on
# //xla/hlo/testlib:hlo_hardware_independent_test_base and links gtest_main.
xla_cc_test(
    name = "cost_analysis_test",
    srcs = ["cost_analysis_test.cc"],
    deps = [
        ":cost_analysis",
        "//xla:shape_util",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:hlo_cost_analysis",
        "//xla/service/cost_modelling:op_cost",
        "//xla/tsl/lib/core:status_test_util",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
    ],
)

# Library built from simulator.{h,cc}. In-package dependencies are :allocation
# and :cost_analysis; :memory_space_assignment in this file depends on it.
cc_library(
    name = "simulator",
    srcs = ["simulator.cc"],
    hdrs = ["simulator.h"],
    deps = [
        ":allocation",
        ":cost_analysis",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

# Test for :simulator, built from simulator_test.cc. It depends on
# //xla/hlo/testlib:hlo_hardware_independent_test_base and links gtest_main.
xla_cc_test(
    name = "simulator_test",
    srcs = ["simulator_test.cc"],
    deps = [
        ":allocation",
        ":cost_analysis",
        ":simulator",
        "//xla:shape_util",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "//xla/service/heap_simulator",
        "//xla/tsl/lib/core:status_test_util",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
    ],
)

# Library built from prefetch_interval_picker.{h,cc}. In-package dependencies
# are :cost_analysis and the generated C++ proto target.
cc_library(
    name = "prefetch_interval_picker",
    srcs = ["prefetch_interval_picker.cc"],
    hdrs = ["prefetch_interval_picker.h"],
    deps = [
        ":cost_analysis",
        ":memory_space_assignment_proto_cc",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:hlo_proto_cc",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:logging",
    ],
)

# Test-only, header-only target (no srcs) exposing testing_utils.h.
# :memory_space_assignment_test and :prefetch_interval_picker_test in this
# file depend on it.
cc_library(
    name = "testing_utils",
    testonly = True,
    hdrs = ["testing_utils.h"],
    deps = [
        ":cost_analysis",
        "//xla:shape_util",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:call_graph",
        "//xla/service:hlo_cost_analysis",
        "//xla/service/cost_modelling:op_cost",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:statusor",
    ],
)

# Library built from memory_bound_loop_optimizer.{h,cc}. In-package
# dependencies are :allocation, :cost_analysis and :options, plus the
# generated C++ proto target; :algorithm in this file depends on it.
cc_library(
    name = "memory_bound_loop_optimizer",
    srcs = ["memory_bound_loop_optimizer.cc"],
    hdrs = ["memory_bound_loop_optimizer.h"],
    deps = [
        ":allocation",
        ":cost_analysis",
        ":memory_space_assignment_proto_cc",
        ":options",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:buffer_value",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_proto_cc",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "//xla/tsl/platform:errors",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
    ],
)

# Test for :memory_bound_loop_optimizer, built from
# memory_bound_loop_optimizer_test.cc. It depends on
# //xla/hlo/testlib:hlo_hardware_independent_test_base and links gtest_main.
# NOTE(review): it lists both gtest_main and "@tsl//tsl/platform:test" --
# confirm whether the tsl test target is still needed by the source.
xla_cc_test(
    name = "memory_bound_loop_optimizer_test",
    srcs = ["memory_bound_loop_optimizer_test.cc"],
    deps = [
        ":allocation",
        ":buffer_interval_comparator",
        ":cost_analysis",
        ":memory_bound_loop_optimizer",
        ":memory_space_assignment",
        ":memory_space_assignment_proto_cc",
        ":options",
        ":prefetch_interval_picker",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:buffer_value",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "//xla/tsl/lib/core:status_test_util",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@com_googlesource_code_re2//:re2",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
        "@tsl//tsl/platform:test",
    ],
)

# Library built from algorithm.{h,cc}. It has the widest in-package fan-in of
# any library here: :allocation, :allocation_value,
# :buffer_interval_comparator, :cost_analysis, :memory_bound_loop_optimizer,
# :options, :repacking, :slice, :tuning_utils and :utils, plus the generated
# C++ proto target. :memory_space_assignment depends on it.
cc_library(
    name = "algorithm",
    srcs = ["algorithm.cc"],
    hdrs = ["algorithm.h"],
    deps = [
        ":allocation",
        ":allocation_value",
        ":buffer_interval_comparator",
        ":cost_analysis",
        ":memory_bound_loop_optimizer",
        ":memory_space_assignment_proto_cc",
        ":options",
        ":repacking",
        ":slice",
        ":tuning_utils",
        ":utils",
        "//xla:debug_options_flags",
        "//xla:shape_tree",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/analysis:alias_info",
        "//xla/hlo/analysis:hlo_alias_analysis",
        "//xla/hlo/analysis:hlo_dataflow_analysis",
        "//xla/hlo/analysis:hlo_operand_index",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/transforms:memory_space_propagation",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:buffer_value",
        "//xla/service:call_graph",
        "//xla/service:computation_layout",
        "//xla/service:hlo_buffer",
        "//xla/service:hlo_proto_cc",
        "//xla/service:hlo_value",
        "//xla/service:time_utils",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)

# Library built from buffer_interval_comparator.{h,cc}. In-package
# dependencies are :cost_analysis and :utils, plus the generated C++ proto
# target.
cc_library(
    name = "buffer_interval_comparator",
    srcs = ["buffer_interval_comparator.cc"],
    hdrs = ["buffer_interval_comparator.h"],
    deps = [
        ":cost_analysis",
        ":memory_space_assignment_proto_cc",
        ":utils",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_live_range",
        "//xla/service:buffer_value",
        "//xla/service:hlo_value",
        "//xla/service/heap_simulator",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# Library built from allocation_value.{h,cc}. Its only in-package dependency
# is :allocation; :options and :algorithm in this file depend on it.
cc_library(
    name = "allocation_value",
    srcs = ["allocation_value.cc"],
    hdrs = ["allocation_value.h"],
    deps = [
        ":allocation",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/service:hlo_value",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

# Test for :prefetch_interval_picker, built from
# prefetch_interval_picker_test.cc. It uses the test-only :testing_utils
# helpers and links gtest_main.
xla_cc_test(
    name = "prefetch_interval_picker_test",
    srcs = ["prefetch_interval_picker_test.cc"],
    deps = [
        ":cost_analysis",
        ":prefetch_interval_picker",
        ":testing_utils",
        "//xla:shape_util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service:hlo_cost_analysis",
        "//xla/service:hlo_value",
        "//xla/service/cost_modelling:op_cost",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@tsl//tsl/platform:statusor",
    ],
)

# Test for :best_fit_repacker, built from best_fit_repacker_test.cc.
# Unlike the other xla_cc_test targets in this file it links
# //xla/tests:xla_internal_test_main with plain ":gtest" instead of
# ":gtest_main".
# NOTE(review): "@tsl//tsl/platform:test" is listed alongside ":gtest" --
# confirm whether the source still includes the tsl test header.
xla_cc_test(
    name = "best_fit_repacker_test",
    srcs = ["best_fit_repacker_test.cc"],
    deps = [
        ":best_fit_repacker",
        "//xla:comparison_util",
        "//xla/service/heap_simulator",
        "//xla/service/heap_simulator:allocation_block",
        "//xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
        "@tsl//tsl/platform:test",
    ],
)
